<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  
  <link rel="shortcut icon" href="../../img/favicon.ico">
  <title>函数式 API 指引 - Keras 中文文档</title>
  <link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>

  <link rel="stylesheet" href="../../css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
  <link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
  
  <script>
    // Current page data
    var mkdocs_page_name = "\u51fd\u6570\u5f0f API \u6307\u5f15";
    var mkdocs_page_input_path = "getting-started/functional-api-guide.md";
    var mkdocs_page_url = "/zh/getting-started/functional-api-guide/";
  </script>
  
  <script src="../../js/jquery-2.1.1.min.js" defer></script>
  <script src="../../js/modernizr-2.8.3.min.js" defer></script>
  <script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
  <script>hljs.initHighlightingOnLoad();</script> 
  
  <script>
      (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
      (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
      m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
      })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

      ga('create', 'UA-61785484-1', 'keras.io');
      ga('send', 'pageview');
  </script>
  
</head>

<body class="wy-body-for-nav" role="document">

  <div class="wy-grid-for-nav">

    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
      <div class="wy-side-nav-search">
        <a href="../.." class="icon icon-home"> Keras 中文文档</a>
        <div role="search">
  <form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" title="Type search term here" />
  </form>
</div>
      </div>

      <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
	<ul class="current">
	  
          
            <li class="toctree-l1">
		
    <a class="" href="../..">主页</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../why-use-keras/">为什么选择 Keras?</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">快速开始</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../sequential-model-guide/">Sequential 顺序模型指引</a>
                </li>
                <li class=" current">
                    
    <a class="current" href="./">函数式 API 指引</a>
    <ul class="subnav">
            
    <li class="toctree-l3"><a href="#keras-api">开始使用 Keras 函数式 API</a></li>
    
        <ul>
        
            <li><a class="toctree-l4" href="#_1">例一：全连接网络</a></li>
        
            <li><a class="toctree-l4" href="#_2">所有的模型都可调用，就像网络层一样</a></li>
        
            <li><a class="toctree-l4" href="#_3">多输入多输出模型</a></li>
        
            <li><a class="toctree-l4" href="#_4">共享网络层</a></li>
        
            <li><a class="toctree-l4" href="#_5">层「节点」的概念</a></li>
        
            <li><a class="toctree-l4" href="#_6">更多的例子</a></li>
        
        </ul>
    

    </ul>
                </li>
                <li class="">
                    
    <a class="" href="../faq/">FAQ 常见问题解答</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">模型</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../models/about-keras-models/">关于 Keras 模型</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/sequential/">Sequential 顺序模型 API</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/model/">函数式 API</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">Layers</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../layers/about-keras-layers/">关于 Keras 网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/core/">核心网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/convolutional/">卷积层 Convolutional</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/pooling/">池化层 Pooling</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/local/">局部连接层 Locally-connected</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/recurrent/">循环层 Recurrent</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/embeddings/">嵌入层 Embedding</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/merge/">融合层 Merge</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/advanced-activations/">高级激活层 Advanced Activations</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/normalization/">标准化层 Normalization</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/noise/">噪声层 Noise</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/wrappers/">层封装器 wrappers</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/writing-your-own-keras-layers/">编写你自己的层</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">数据预处理</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../preprocessing/sequence/">序列预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/text/">文本预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/image/">图像预处理</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../losses/">损失函数 Losses</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../metrics/">评估标准 Metrics</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../optimizers/">优化器 Optimizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../activations/">激活函数 Activations</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../callbacks/">回调函数 Callbacks</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../datasets/">常用数据集 Datasets</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../applications/">应用 Applications</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../backend/">后端 Backend</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../initializers/">初始化 Initializers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../regularizers/">正则化 Regularizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../constraints/">约束 Constraints</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../visualization/">可视化 Visualization</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../scikit-learn-api/">Scikit-learn API</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../utils/">工具</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../contributing/">贡献</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">经典样例</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../examples/addition_rnn/">Addition RNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/babi_rnn/">Baby RNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/babi_memnn/">Baby MemNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/cifar10_cnn/">CIFAR-10 CNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/cifar10_cnn_capsule/">CIFAR-10 CNN-Capsule</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/cifar10_cnn_tfaugment2d/">CIFAR-10 CNN with augmentation (TF)</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/cifar10_resnet/">CIFAR-10 ResNet</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/conv_filter_visualization/">Convolution filter visualization</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/image_ocr/">Image OCR</a>
                </li>
                <li class="">
                    
    <a class="" href="../../examples/imdb_bidirectional_lstm/">Bidirectional LSTM</a>
                </li>
    </ul>
	    </li>
          
        </ul>
      </div>
      &nbsp;
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" role="navigation" aria-label="top navigation">
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../..">Keras 中文文档</a>
      </nav>

      
      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <li><a href="../..">Docs</a> &raquo;</li>
    
      
        
          <li>快速开始 &raquo;</li>
        
      
    
    <li>函数式 API 指引</li>
    <li class="wy-breadcrumbs-aside">
      
        <a href="https://github.com/keras-team/keras-docs-zh/edit/master/docs/getting-started/functional-api-guide.md"
          class="icon icon-github"> Edit on GitHub</a>
      
    </li>
  </ul>
  <hr/>
</div>
          <div role="main">
            <div class="section">
              
                <h1 id="keras-api">开始使用 Keras 函数式 API</h1>
<p>Keras 函数式 API 是定义复杂模型（如多输出模型、有向无环图，或具有共享层的模型）的方法。</p>
<p>这部分文档假设你已经对 <code>Sequential</code> 顺序模型比较熟悉。</p>
<p>让我们先从一些简单的例子开始。</p>
<hr />
<h2 id="_1">例一：全连接网络</h2>
<p><code>Sequential</code> 模型可能是实现这种网络的一个更好选择，但这个例子能够帮助我们进行一些简单的理解。</p>
<ul>
<li>网络层的实例是可调用的，它以张量为参数，并且返回一个张量</li>
<li>输入和输出均为张量，它们都可以用来定义一个模型（<code>Model</code>）</li>
<li>这样的模型同 Keras 的 <code>Sequential</code> 模型一样，都可以被训练</li>
</ul>
<pre><code class="python">from keras.layers import Input, Dense
from keras.models import Model

# 这部分返回一个张量
inputs = Input(shape=(784,))

# 层的实例是可调用的，它以张量为参数，并且返回一个张量
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

# 这部分创建了一个包含输入层和三个全连接层的模型
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(data, labels)  # 开始训练
</code></pre>

<hr />
<h2 id="_2">所有的模型都可调用，就像网络层一样</h2>
<p>利用函数式 API，可以轻易地重用训练好的模型：可以将任何模型看作是一个层，然后通过传递一个张量来调用它。注意，在调用模型时，您不仅重用模型的<em>结构</em>，还重用了它的权重。</p>
<pre><code class="python">x = Input(shape=(784,))
# 这是可行的，并且返回上面定义的 10-way softmax。
y = model(x)
</code></pre>

<p>这种方式能允许我们快速创建可以处理<em>序列输入</em>的模型。只需一行代码，你就将图像分类模型转换为视频分类模型。</p>
<pre><code class="python">from keras.layers import TimeDistributed

# 输入张量是 20 个时间步的序列，
# 每一个时间为一个 784 维的向量
input_sequences = Input(shape=(20, 784))

# 这部分将我们之前定义的模型应用于输入序列中的每个时间步。
# 之前定义的模型的输出是一个 10-way softmax，
# 因而下面的层的输出将是维度为 10 的 20 个向量的序列。
processed_sequences = TimeDistributed(model)(input_sequences)
</code></pre>

<hr />
<h2 id="_3">多输入多输出模型</h2>
<p>以下是函数式 API 的一个很好的例子：具有多个输入和输出的模型。函数式 API 使处理大量交织的数据流变得容易。</p>
<p>来考虑下面的模型。我们试图预测 Twitter 上的一条新闻标题有多少转发和点赞数。模型的主要输入将是新闻标题本身，即一系列词语，但是为了增添趣味，我们的模型还添加了其他的辅助输入来接收额外的数据，例如新闻标题的发布的时间等。
该模型也将通过两个损失函数进行监督学习。较早地在模型中使用主损失函数，是深度学习模型的一个良好正则方法。</p>
<p>模型结构如下图所示：</p>
<p><img src="/img/multi-input-multi-output-graph.png" alt="multi-input-multi-output-graph" style="width: 400px;"/></p>
<p>让我们用函数式 API 来实现它。</p>
<p>主要输入接收新闻标题本身，即一个整数序列（每个整数编码一个词）。
这些整数在 1 到 10,000 之间（10,000 个词的词汇表），且序列长度为 100 个词。</p>
<pre><code class="python">from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model

# 标题输入：接收一个含有 100 个整数的序列，每个整数在 1 到 10000 之间。
# 注意我们可以通过传递一个 &quot;name&quot; 参数来命名任何层。
main_input = Input(shape=(100,), dtype='int32', name='main_input')

# Embedding 层将输入序列编码为一个稠密向量的序列，
# 每个向量维度为 512。
x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)

# LSTM 层把向量序列转换成单个向量，
# 它包含整个序列的上下文信息
lstm_out = LSTM(32)(x)
</code></pre>

<p>在这里，我们插入辅助损失，使得即使在模型主损失很高的情况下，LSTM 层和 Embedding 层都能被平稳地训练。</p>
<pre><code class="python">auxiliary_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)
</code></pre>

<p>此时，我们将辅助输入数据与 LSTM 层的输出连接起来，输入到模型中：</p>
<pre><code class="python">auxiliary_input = Input(shape=(5,), name='aux_input')
x = keras.layers.concatenate([lstm_out, auxiliary_input])

# 堆叠多个全连接网络层
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)

# 最后添加主要的逻辑回归层
main_output = Dense(1, activation='sigmoid', name='main_output')(x)
</code></pre>

<p>然后定义一个具有两个输入和两个输出的模型：</p>
<pre><code class="python">model = Model(inputs=[main_input, auxiliary_input], outputs=[main_output, auxiliary_output])
</code></pre>

<p>现在编译模型，并给辅助损失分配一个 0.2 的权重。如果要为不同的输出指定不同的 <code>loss_weights</code> 或 <code>loss</code>，可以使用列表或字典。
在这里，我们给 <code>loss</code> 参数传递单个损失函数，这个损失将用于所有的输出。</p>
<pre><code class="python">model.compile(optimizer='rmsprop', loss='binary_crossentropy',
              loss_weights=[1., 0.2])
</code></pre>

<p>我们可以通过传递输入数组和目标数组的列表来训练模型：</p>
<pre><code class="python">model.fit([headline_data, additional_data], [labels, labels],
          epochs=50, batch_size=32)
</code></pre>

<p>由于输入和输出均被命名了（在定义时传递了一个 <code>name</code> 参数），我们也可以通过以下方式编译模型：</p>
<pre><code class="python">model.compile(optimizer='rmsprop',
              loss={'main_output': 'binary_crossentropy', 'aux_output': 'binary_crossentropy'},
              loss_weights={'main_output': 1., 'aux_output': 0.2})

# 然后使用以下方式训练：
model.fit({'main_input': headline_data, 'aux_input': additional_data},
          {'main_output': labels, 'aux_output': labels},
          epochs=50, batch_size=32)
</code></pre>

<hr />
<h2 id="_4">共享网络层</h2>
<p>函数式 API 的另一个用途是使用共享网络层的模型。我们来看看共享层。</p>
<p>来考虑推特推文数据集。我们想要建立一个模型来分辨两条推文是否来自同一个人（例如，通过推文的相似性来对用户进行比较）。</p>
<p>实现这个目标的一种方法是建立一个模型，将两条推文编码成两个向量，连接向量，然后添加逻辑回归层；这将输出两条推文来自同一作者的概率。模型将接收一对对正负表示的推特数据。</p>
<p>由于这个问题是对称的，编码第一条推文的机制应该被完全重用来编码第二条推文（权重及其他全部）。这里我们使用一个共享的 LSTM 层来编码推文。</p>
<p>让我们使用函数式 API 来构建它。首先我们将一条推特转换为一个尺寸为 <code>(280, 256)</code> 的矩阵，即每条推特 280 字符，每个字符为 256 维的 one-hot 编码向量 （取 256 个常用字符）。</p>
<pre><code class="python">import keras
from keras.layers import Input, LSTM, Dense
from keras.models import Model

tweet_a = Input(shape=(280, 256))
tweet_b = Input(shape=(280, 256))
</code></pre>

<p>要在不同的输入上共享同一个层，只需实例化该层一次，然后根据需要传入你想要的输入即可：</p>
<pre><code class="python"># 这一层可以输入一个矩阵，并返回一个 64 维的向量
shared_lstm = LSTM(64)

# 当我们重用相同的图层实例多次，图层的权重也会被重用 (它其实就是同一层)
encoded_a = shared_lstm(tweet_a)
encoded_b = shared_lstm(tweet_b)

# 然后再连接两个向量：
merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=-1)

# 再在上面添加一个逻辑回归层
predictions = Dense(1, activation='sigmoid')(merged_vector)

# 定义一个连接推特输入和预测的可训练的模型
model = Model(inputs=[tweet_a, tweet_b], outputs=predictions)

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit([data_a, data_b], labels, epochs=10)
</code></pre>

<p>让我们暂停一会，看看如何读取共享层的输出或输出尺寸。</p>
<hr />
<h2 id="_5">层「节点」的概念</h2>
<p>每当你在某个输入上调用一个层时，都将创建一个新的张量（层的输出），并且为该层添加一个「节点」，将输入张量连接到输出张量。当多次调用同一个图层时，该图层将拥有多个节点索引 (0, 1, 2...)。</p>
<p>在之前版本的 Keras 中，可以通过 <code>layer.get_output()</code> 来获得层实例的输出张量，或者通过 <code>layer.output_shape</code> 来获取其输出形状。现在你依然可以这么做（除了 <code>get_output()</code> 已经被 <code>output</code> 属性替代）。但是如果一个层与多个输入连接呢？</p>
<p>只要一个层仅仅连接到一个输入，就不会有困惑，<code>.output</code> 会返回层的唯一输出：</p>
<pre><code class="python">a = Input(shape=(280, 256))

lstm = LSTM(32)
encoded_a = lstm(a)

assert lstm.output == encoded_a
</code></pre>

<p>但是如果该层有多个输入，那就会出现问题：</p>
<pre><code class="python">a = Input(shape=(280, 256))
b = Input(shape=(280, 256))

lstm = LSTM(32)
encoded_a = lstm(a)
encoded_b = lstm(b)

lstm.output
</code></pre>

<pre><code>&gt;&gt; AttributeError: Layer lstm_1 has multiple inbound nodes,
hence the notion of &quot;layer output&quot; is ill-defined.
Use `get_output_at(node_index)` instead.
</code></pre>

<p>好吧，通过下面的方法可以解决：</p>
<pre><code class="python">assert lstm.get_output_at(0) == encoded_a
assert lstm.get_output_at(1) == encoded_b
</code></pre>

<p>够简单，对吧？</p>
<p><code>input_shape</code> 和 <code>output_shape</code> 这两个属性也是如此：只要该层只有一个节点，或者只要所有节点具有相同的输入/输出尺寸，那么「层输出/输入尺寸」的概念就被很好地定义，并且将由 <code>layer.output_shape</code> / <code>layer.input_shape</code> 返回。但是比如说，如果将一个 <code>Conv2D</code> 层先应用于尺寸为 <code>(32，32，3)</code> 的输入，再应用于尺寸为 <code>(64, 64, 3)</code> 的输入，那么这个层就会有多个输入/输出尺寸，你将不得不通过指定它们所属节点的索引来获取它们：</p>
<pre><code class="python">a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))

conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)

# 到目前为止只有一个输入，以下可行：
assert conv.input_shape == (None, 32, 32, 3)

conved_b = conv(b)
# 现在 `.input_shape` 属性不可行，但是这样可以：
assert conv.get_input_shape_at(0) == (None, 32, 32, 3)
assert conv.get_input_shape_at(1) == (None, 64, 64, 3)
</code></pre>

<hr />
<h2 id="_6">更多的例子</h2>
<p>代码示例仍然是起步的最佳方式，所以这里还有更多的例子。</p>
<h3 id="inception">Inception 模型</h3>
<p>有关 Inception 结构的更多信息，请参阅 <a href="http://arxiv.org/abs/1409.4842">Going Deeper with Convolutions</a>。</p>
<pre><code class="python">from keras.layers import Conv2D, MaxPooling2D, Input

input_img = Input(shape=(256, 256, 3))

tower_1 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)
tower_1 = Conv2D(64, (3, 3), padding='same', activation='relu')(tower_1)

tower_2 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)
tower_2 = Conv2D(64, (5, 5), padding='same', activation='relu')(tower_2)

tower_3 = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(input_img)
tower_3 = Conv2D(64, (1, 1), padding='same', activation='relu')(tower_3)

output = keras.layers.concatenate([tower_1, tower_2, tower_3], axis=1)
</code></pre>

<h3 id="_7">卷积层上的残差连接</h3>
<p>有关残差网络 (Residual Network) 的更多信息，请参阅 <a href="http://arxiv.org/abs/1512.03385">Deep Residual Learning for Image Recognition</a>。</p>
<pre><code class="python">from keras.layers import Conv2D, Input

# 输入张量为 3 通道 256x256 图像
x = Input(shape=(256, 256, 3))
# 3 输出通道（与输入通道相同）的 3x3 卷积核
y = Conv2D(3, (3, 3), padding='same')(x)
# 返回 x + y
z = keras.layers.add([x, y])
</code></pre>

<h3 id="_8">共享视觉模型</h3>
<p>该模型在两个输入上重复使用同一个图像处理模块，以判断两个 MNIST 数字是否为相同的数字。</p>
<pre><code class="python">from keras.layers import Conv2D, MaxPooling2D, Input, Dense, Flatten
from keras.models import Model

# 首先，定义视觉模型
digit_input = Input(shape=(27, 27, 1))
x = Conv2D(64, (3, 3))(digit_input)
x = Conv2D(64, (3, 3))(x)
x = MaxPooling2D((2, 2))(x)
out = Flatten()(x)

vision_model = Model(digit_input, out)

# 然后，定义区分数字的模型
digit_a = Input(shape=(27, 27, 1))
digit_b = Input(shape=(27, 27, 1))

# 视觉模型将被共享，包括权重和其他所有
out_a = vision_model(digit_a)
out_b = vision_model(digit_b)

concatenated = keras.layers.concatenate([out_a, out_b])
out = Dense(1, activation='sigmoid')(concatenated)

classification_model = Model([digit_a, digit_b], out)
</code></pre>

<h3 id="_9">视觉问答模型</h3>
<p>当被问及关于图片的自然语言问题时，该模型可以选择正确的单词作答。</p>
<p>它通过将问题和图像编码成向量，然后连接两者，在上面训练一个逻辑回归，来从词汇表中挑选一个可能的单词作答。</p>
<pre><code class="python">from keras.layers import Conv2D, MaxPooling2D, Flatten
from keras.layers import Input, LSTM, Embedding, Dense
from keras.models import Model, Sequential

# 首先，让我们用 Sequential 来定义一个视觉模型。
# 这个模型会把一张图像编码为向量。
vision_model = Sequential()
vision_model.add(Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(224, 224, 3)))
vision_model.add(Conv2D(64, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
vision_model.add(Conv2D(128, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
vision_model.add(Conv2D(256, (3, 3), activation='relu'))
vision_model.add(Conv2D(256, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Flatten())

# 现在让我们用视觉模型来得到一个输出张量：
image_input = Input(shape=(224, 224, 3))
encoded_image = vision_model(image_input)

# 接下来，定义一个语言模型来将问题编码成一个向量。
# 每个问题最长 100 个词，词的索引从 1 到 9999.
question_input = Input(shape=(100,), dtype='int32')
embedded_question = Embedding(input_dim=10000, output_dim=256, input_length=100)(question_input)
encoded_question = LSTM(256)(embedded_question)

# 连接问题向量和图像向量：
merged = keras.layers.concatenate([encoded_question, encoded_image])

# 然后在上面训练一个 1000 词的逻辑回归模型：
output = Dense(1000, activation='softmax')(merged)

# 最终模型：
vqa_model = Model(inputs=[image_input, question_input], outputs=output)

# 下一步就是在真实数据上训练模型。
</code></pre>

<h3 id="_10">视频问答模型</h3>
<p>现在我们已经训练了图像问答模型，我们可以很快地将它转换为视频问答模型。在适当的训练下，你可以给它展示一小段视频（例如 100 帧的人体动作），然后问它一个关于这段视频的问题（例如，「这个人在做什么运动？」 -&gt; 「足球」）。</p>
<pre><code class="python">from keras.layers import TimeDistributed

video_input = Input(shape=(100, 224, 224, 3))
# 这是基于之前定义的视觉模型（权重被重用）构建的视频编码
encoded_frame_sequence = TimeDistributed(vision_model)(video_input)  # 输出为向量的序列
encoded_video = LSTM(256)(encoded_frame_sequence)  # 输出为一个向量

# 这是问题编码器的模型级表示，重复使用与之前相同的权重：
question_encoder = Model(inputs=question_input, outputs=encoded_question)

# 让我们用它来编码这个问题：
video_question_input = Input(shape=(100,), dtype='int32')
encoded_video_question = question_encoder(video_question_input)

# 这就是我们的视频问答模式：
merged = keras.layers.concatenate([encoded_video, encoded_video_question])
output = Dense(1000, activation='softmax')(merged)
video_qa_model = Model(inputs=[video_input, video_question_input], outputs=output)
</code></pre>
              
            </div>
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="../faq/" class="btn btn-neutral float-right" title="FAQ 常见问题解答">Next <span class="icon icon-circle-arrow-right"></span></a>
      
      
        <a href="../sequential-model-guide/" class="btn btn-neutral" title="Sequential 顺序模型指引"><span class="icon icon-circle-arrow-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <!-- Copyright etc -->
    
  </div>

  Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
      
        </div>
      </div>

    </section>

  </div>

  <div class="rst-versions" role="note" style="cursor: pointer">
    <span class="rst-current-version" data-toggle="rst-current-version">
      
          <a href="https://github.com/keras-team/keras-docs-zh/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
      
      
        <span><a href="../sequential-model-guide/" style="color: #fcfcfc;">&laquo; Previous</a></span>
      
      
        <span style="margin-left: 15px"><a href="../faq/" style="color: #fcfcfc">Next &raquo;</a></span>
      
    </span>
</div>
    <script>var base_url = '../..';</script>
    <script src="../../js/theme.js" defer></script>
      <script src="../../search/main.js" defer></script>

</body>
</html>
